# yellowbrick.regressor.prediction_error
# Comparison of the predicted vs. actual values for regression problems
#
# Author:   Rebecca Bilbro
# Author:   Benjamin Bengfort
# Created:  Fri Jun 03 10:30:36 2016 -0700
#
# Copyright (C) 2016 The scikit-yb developers
# For license information, see LICENSE.txt
#
# ID: prediction_error.py []  $

"""
Comparison of the predicted vs. actual values for regression problems
"""

##########################################################################
## Imports
##########################################################################

from yellowbrick.style.palettes import LINE_COLOR
from yellowbrick.exceptions import YellowbrickValueError
from yellowbrick.bestfit import draw_best_fit, draw_identity_line
from yellowbrick.regressor.base import RegressionScoreVisualizer


## Packages for export
__all__ = ["PredictionError", "prediction_error"]


##########################################################################
## Prediction Error Plots
##########################################################################


class PredictionError(RegressionScoreVisualizer):
    """
    Scatter plot of a regressor's predictions against the actual targets.

    Every scored instance becomes one point whose x coordinate is the actual
    target value and whose y coordinate is the value predicted by the wrapped
    model. The plot is used to detect noise or heteroscedasticity along a
    range of the target domain.

    Parameters
    ----------

    estimator : a Scikit-Learn regressor
        Should be an instance of a regressor, otherwise will raise a
        YellowbrickTypeError exception on instantiation.
        If the estimator is not fitted, it is fit when the visualizer is fitted,
        unless otherwise specified by ``is_fitted``.

    ax : matplotlib Axes, default: None
        The axes to plot the figure on. If None is passed in the current axes
        will be used (or generated if required).

    shared_limits : bool, default: True
        When True, ``finalize()`` gives the X and Y axes the same limits and
        an equal aspect ratio, producing a square graphic with a true
        45 degree line. Under- or over-prediction is easier to diagnose in
        this form, though the figure becomes more sparse. Set to False to
        localize the points, keeping in mind that the figure is then
        distorted and the analysis must account for it.

    bestfit : bool, default: True
        Draw a linear best fit line through the (actual, predicted) points to
        estimate the correlation between the predicted and measured value of
        the target variable. Its color is taken from the ``line_color``
        keyword argument.

    identity : bool, default: True
        Draw the 45 degree identity line, y=x, to better show the
        relationship or pattern of the residuals, e.g. whether the model is
        over- or under-estimating the given values. It is drawn with the
        ``line_color`` keyword argument at half opacity.

    alpha : float, default: 0.75
        Transparency of the scatter points, where 1 is completely opaque and
        0 is completely transparent. Lower values make densely clustered
        points more visible.

    is_fitted : bool or str, default='auto'
        Specify if the wrapped estimator is already fitted. If False, the estimator
        will be fit when the visualizer is fit, otherwise, the estimator will not be
        modified. If 'auto' (default), a helper method will check if the estimator
        is fitted before fitting it again.

    kwargs : dict
        Keyword arguments that are passed to the base class and may influence
        the visualization as defined in other Visualizers. The optional
        ``point_color`` and ``line_color`` entries are additionally read here
        to color the scatter points and the best fit / identity lines.

    Attributes
    ----------

    score_ : float
        The R^2 score that specifies the goodness of fit of the underlying
        regression model to the test data.

    Examples
    --------

    >>> from yellowbrick.regressor import PredictionError
    >>> from sklearn.linear_model import Lasso
    >>> model = PredictionError(Lasso())
    >>> model.fit(X_train, y_train)
    >>> model.score(X_test, y_test)
    >>> model.show()

    Notes
    -----

    PredictionError is a ScoreVisualizer, meaning that it wraps a model and
    its primary entry point is the `score()` method.
    """

    def __init__(
        self,
        estimator,
        ax=None,
        shared_limits=True,
        bestfit=True,
        identity=True,
        alpha=0.75,
        is_fitted="auto",
        **kwargs
    ):
        # The fitted-check policy is recorded before the base initializer runs
        self.is_fitted = is_fitted

        # The base class receives the estimator, the axes and every extra kwarg
        super(PredictionError, self).__init__(estimator, ax=ax, **kwargs)

        # Options controlling what is drawn and how
        self.alpha = alpha
        self.identity = identity
        self.bestfit = bestfit
        self.shared_limits = shared_limits

        # Colors for the scatter points and for the best fit / identity lines;
        # a point color of None is handed to scatter() unchanged
        point_color = kwargs.get("point_color", None)
        line_color = kwargs.get("line_color", LINE_COLOR)
        self.colors = {"point": point_color, "line": line_color}

    def score(self, X, y, **kwargs):
        """
        Score the wrapped estimator on test data and plot its predictions.

        The parent ``score()`` is invoked first (it sets ``score_`` on the
        visualizer), then predictions are generated for ``X`` and handed to
        ``draw()`` together with the actual values.

        Parameters
        ----------
        X : array-like
            X (also X_test) holds the features of the test set to predict on

        y : array-like
            y (also y_test) holds the actual target values to score against

        kwargs : dict
            Keyword arguments forwarded to the parent ``score()`` method.

        Returns
        -------
        score : float
            The value of ``score_`` after scoring.
        """
        # The parent call populates self.score_, which draw() uses for its label
        super(PredictionError, self).score(X, y, **kwargs)

        self.draw(y, self.predict(X))
        return self.score_

    def draw(self, y, y_pred):
        """
        Plot the predicted values against the actual values.

        Adds a scatter of (y, y_pred) labeled with the current ``score_``,
        optionally a linear best fit line, and sets the axis limits to the
        data range padded by one unit on each side.

        Parameters
        ----------
        y : ndarray or Series of length n
            An array or series of target or class values

        y_pred : ndarray or Series of length n
            An array or series of predicted target values

        Returns
        -------
        ax : matplotlib Axes
            The axis with the plotted figure
        """
        axes = self.ax

        # The scatter's legend entry reports the score to three decimals
        r2_label = "$R^2 = {:0.3f}$".format(self.score_)
        axes.scatter(
            y, y_pred, c=self.colors["point"], alpha=self.alpha, label=r2_label
        )

        # TODO If score happens inside a loop, draw gets called multiple times.
        # Ideally we'd want the best fit line to be drawn only once
        if self.bestfit:
            fit_style = {
                "ls": "--",
                "lw": 2,
                "c": self.colors["line"],
                "label": "best fit",
            }
            draw_best_fit(y, y_pred, axes, "linear", **fit_style)

        # Limits follow the data range: actual values on x, predictions on y.
        # NOTE: shared_limits will be accounted for in finalize()
        # TODO: do better than add one for really small residuals
        x_low, x_high = y.min() - 1, y.max() + 1
        axes.set_xlim(x_low, x_high)

        y_low, y_high = y_pred.min() - 1, y_pred.max() + 1
        axes.set_ylim(y_low, y_high)

        return axes

    def finalize(self, **kwargs):
        """
        Put the finishing touches on the figure: title, optionally squared
        axes with shared limits, optionally the identity line, axis labels
        and the legend.

        Parameters
        ----------
        kwargs: generic keyword arguments.

        Notes
        -----
        Generally this method is called from show and not directly by the user.
        """
        # Title names the wrapped model
        self.set_title("Prediction Error for {}".format(self.name))

        axes = self.ax

        # Square the axes to ensure a 45 degree line
        if self.shared_limits:
            # Read the limits currently set on each axis
            y_low, y_high = axes.get_ylim()
            x_low, x_high = axes.get_xlim()

            # One range wide enough to capture all data on both axes
            shared = (min(y_low, x_low), max(y_high, x_high))

            axes.set_xlim(shared)
            axes.set_ylim(shared)

            # Equal scaling on both axes makes the plot area square
            axes.set_aspect("equal", adjustable="box")

        # Draw the 45 degree line in the line color at half opacity
        if self.identity:
            identity_style = {
                "ls": "--",
                "lw": 2,
                "c": self.colors["line"],
                "alpha": 0.5,
                "label": "identity",
            }
            draw_identity_line(ax=axes, **identity_style)

        # Predictions are on the vertical axis, actual values on the horizontal
        axes.set_ylabel(r"$\hat{y}$")
        axes.set_xlabel(r"$y$")

        # Note: the manual_legend utility would be preferable here, because a
        # low alpha value also makes the R2 entry in the legend translucent.
        # It is not used because a manual legend would override the best fit
        # and identity line entries; the best fit in particular is plotted in
        # draw() since it depends on y and y_pred.
        axes.legend(loc="best", frameon=True)


##########################################################################
## Quick Method
##########################################################################


def prediction_error(
    estimator,
    X_train,
    y_train,
    X_test=None,
    y_test=None,
    ax=None,
    shared_limits=True,
    bestfit=True,
    identity=True,
    alpha=0.75,
    is_fitted="auto",
    show=True,
    **kwargs
):
    """Quickly plot a prediction error visualizer

    Plot the actual targets from the dataset against the
    predicted values generated by our model(s).

    This helper function is a quick wrapper to utilize the PredictionError
    ScoreVisualizer for one-off analysis.

    Parameters
    ----------
    estimator : the Scikit-Learn estimator (should be a regressor)
        Should be an instance of a regressor, otherwise will raise a
        YellowbrickTypeError exception on instantiation.
        If the estimator is not fitted, it is fit when the visualizer is fitted,
        unless otherwise specified by ``is_fitted``.

    X_train : ndarray or DataFrame of shape n x m
        A feature array of n instances with m features the model is trained on.
        Used to fit the visualizer and also to score the visualizer if test splits are
        not directly specified.

    y_train : ndarray or Series of length n
        An array or series of target or class values. Used to fit the visualizer and
        also to score the visualizer if test splits are not specified.

    X_test : ndarray or DataFrame of shape n x m, default: None
        An optional feature array of n instances with m features that the model
        is scored on if specified, using X_train as the training data.
        Must be given together with ``y_test``.

    y_test : ndarray or Series of length n, default: None
        An optional array or series of target or class values that serve as actual
        labels for X_test for scoring purposes. Must be given together with
        ``X_test``.

    ax : matplotlib Axes
        The axes to plot the figure on.

    shared_limits : bool, default: True
        If shared_limits is True, the range of the X and Y axis limits will
        be identical, creating a square graphic with a true 45 degree line.
        In this form, it is easier to diagnose under- or over- prediction,
        though the figure will become more sparse. To localize points, set
        shared_limits to False, but note that this will distort the figure
        and should be accounted for during analysis.

    bestfit : bool, default: True
        Draw a linear best fit line to estimate the correlation between the
        predicted and measured value of the target variable. The color of
        the bestfit line is determined by the ``line_color`` argument.

    identity: bool, default: True
        Draw the 45 degree identity line, y=x in order to better show the
        relationship or pattern of the residuals. E.g. to estimate if the
        model is over- or under- estimating the given values. The color of the
        identity line is a muted version of the ``line_color`` argument.

    alpha : float, default: 0.75
        Specify a transparency where 1 is completely opaque and 0 is completely
        transparent. This property makes densely clustered points more visible.

    is_fitted : bool or str, default='auto'
        Specify if the wrapped estimator is already fitted. If False, the estimator
        will be fit when the visualizer is fit, otherwise, the estimator will not be
        modified. If 'auto' (default), a helper method will check if the estimator
        is fitted before fitting it again.

    show: bool, default: True
        If True, calls ``show()``, which in turn calls ``plt.show()`` however you cannot
        call ``plt.savefig`` from this signature, nor ``clear_figure``. If False, simply
        calls ``finalize()``

    kwargs : dict
        Keyword arguments that are passed to the base class and may influence
        the visualization as defined in other Visualizers.

    Returns
    -------
    viz : PredictionError
        The visualizer after it has been fit, scored, and either shown or
        finalized. The axes that were drawn on are available as ``viz.ax``.

    Raises
    ------
    YellowbrickValueError
        If only one of ``X_test`` and ``y_test`` is specified. The check
        happens before the visualizer is fit.
    """
    # Instantiate the visualizer
    visualizer = PredictionError(
        estimator,
        ax=ax,
        shared_limits=shared_limits,
        bestfit=bestfit,
        identity=identity,
        alpha=alpha,
        is_fitted=is_fitted,
        **kwargs
    )

    # Fail fast: a half-specified test split is a caller error, so reject it
    # before any work is spent on (or state is changed by) fitting.
    if (X_test is None) != (y_test is None):
        raise YellowbrickValueError(
            "both X_test and y_test are required if one is specified"
        )

    visualizer.fit(X_train, y_train)

    # Score on the test split if provided, otherwise on the training data.
    # After the check above, X_test and y_test are either both set or both None.
    if X_test is not None:
        visualizer.score(X_test, y_test)
    else:
        visualizer.score(X_train, y_train)

    if show:
        visualizer.show()
    else:
        visualizer.finalize()

    # Return the visualizer itself (not just its axes)
    return visualizer
